{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "0dbd2fb7-f99b-467c-82b8-d0ddc4a39c2e",
   "metadata": {},
   "source": [
    "# Phenotypic traits prediction with pretrained Bacformer tutorial\n",
    "\n",
    "This tutorial outlines how to use genome embeddings from pretrained Bacformer to predict phenotypic labels.\n",
    "\n",
    "We provide a dataset containing precomputed genome embeddings and 139 diverse phenotypic labels. We show how to train and evaluate\n",
    "phenotype prediction from a genome embedding using a simple logistic regression model.\n",
    "\n",
    "If your task is highly challenging or requires embedding genomes, we recommend following `finetune_phenotypic_traits_prediction_tutorial.ipynb`, which outlines how to finetune the entire\n",
    "Bacformer model for phenotypic traits prediction.\n",
    "\n",
    "Before you start, make sure the [datasets](https://huggingface.co/docs/datasets/en/installation) and [scikit-learn](https://scikit-learn.org/stable/install.html) packages are installed; the code also imports `numpy` and `pandas`.\n",
    "\n",
    "## Step 1: Import required dependencies"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "id": "fa61441a-f10e-48ed-82a6-a0a66f8522bd",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "from datasets import load_dataset\n",
    "\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.metrics import roc_auc_score\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "from sklearn.pipeline import Pipeline"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dce18cee-3e03-4919-9adf-52cb183ed779",
   "metadata": {},
   "source": [
    "## Step 2: Load the dataset and subset to a selected phenotype\n",
    "\n",
    "We load the precomputed genome embeddings together with the phenotypic trait labels (139 of them).\n",
    "\n",
    "As an example, we select the `gideon_Catalase` phenotype. The trait being predicted is `Catalase`: whether a bacterium produces the catalase enzyme, which breaks down hydrogen peroxide (H₂O₂) into water and oxygen and thereby protects the cell from oxidative stress. The `gideon` prefix indicates the source of the phenotype [1].\n",
    "\n",
    "The `gideon_Catalase` phenotype is a binary classification problem."
   ]
  },
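  {
   "cell_type": "markdown",
   "id": "label-encoding-sketch",
   "metadata": {},
   "source": [
    "As a hedged illustration of the filtering and label encoding performed in the next cell, here is the same logic on a toy table. The three rows and their 2-dimensional embeddings are made up; only the column names and the `+` / `-` label convention follow the tutorial's own code.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "# Toy stand-in for the real dataset (hypothetical values)\n",
    "toy_df = pd.DataFrame({\n",
    "    'bacformer_genome_embedding': [\n",
    "        np.array([0.1, 0.2]),\n",
    "        np.array([0.3, 0.4]),\n",
    "        np.array([0.5, 0.6]),\n",
    "    ],\n",
    "    'gideon_Catalase': ['+', None, '-'],\n",
    "})\n",
    "\n",
    "# Keep only the genomes that have a label for the chosen phenotype\n",
    "labelled = toy_df[toy_df['gideon_Catalase'].notna()]\n",
    "\n",
    "# Stack the per-genome embeddings into an (n_genomes, dim) matrix\n",
    "X_toy = np.vstack(labelled['bacformer_genome_embedding'].to_list())\n",
    "# Encode '+' as 1 and '-' as 0\n",
    "y_toy = labelled['gideon_Catalase'].map({'+': 1, '-': 0}).to_numpy()\n",
    "\n",
    "print(X_toy.shape, y_toy)\n",
    "```\n",
    "\n",
    "The unlabelled middle row is dropped, so the feature matrix and the label vector stay aligned."
   ]
  },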
  {
   "cell_type": "code",
   "execution_count": 39,
   "id": "a48df592-6935-4a70-9463-18530bd8490c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# load the dataset and convert it to a pandas DataFrame\n",
    "df = load_dataset(\"macwiatrak/bacformer-genome-embeddings-with-phenotypic-traits-labels\", split=\"train\").to_pandas()\n",
    "\n",
    "# select the phenotype\n",
    "phenotype = \"gideon_Catalase\"\n",
    "# remove the genomes with NaN values for the phenotype of interest\n",
    "phenotype_df = df[df[phenotype].notna()].copy()\n",
    "\n",
    "# get the feature matrix X and the label vector y\n",
    "X = np.vstack(phenotype_df[\"bacformer_genome_embedding\"].to_list())\n",
    "y = phenotype_df[phenotype].map({'+': 1, '-': 0}).values"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "56239aec-030a-472f-8895-69b6cb9fb4f5",
   "metadata": {},
   "source": [
    "## Step 3: Perform train / val / test split\n",
    "\n",
    "Perform stratified train, val, test splir with `60 / 20 / 20` ratio. We use the validation set for hyperparameter search.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "id": "81cd12db-7a44-455c-9f63-4191d8ec3075",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ------------------------------------------------------------------\n",
    "# Step 3: 60 / 20 / 20 stratified split (train → 0.6, val → 0.2, test → 0.2)\n",
    "# ------------------------------------------------------------------\n",
    "X_train_val, X_test, y_train_val, y_test = train_test_split(\n",
    "    X, y, test_size=0.20, random_state=42, stratify=y\n",
    ")\n",
    "X_train, X_val, y_train, y_val = train_test_split(\n",
    "    X_train_val, y_train_val,\n",
    "    test_size=0.25,  # 0.25 × 0.80 = 0.20\n",
    "    random_state=42,\n",
    "    stratify=y_train_val\n",
    ")"
   ]
  },
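  {
   "cell_type": "markdown",
   "id": "split-arithmetic-sketch",
   "metadata": {},
   "source": [
    "A quick sanity check of the arithmetic: holding out 20% first and then 25% of the remaining 80% gives 60 / 20 / 20. The sketch below verifies this on made-up toy arrays (100 samples, 30% positives); it is an illustration only and not part of the tutorial pipeline.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "# Toy data (hypothetical): 100 samples, 2 features, 30 positive labels\n",
    "X_toy = np.arange(200).reshape(100, 2)\n",
    "y_toy = np.array([0] * 70 + [1] * 30)\n",
    "\n",
    "# First hold out 20% of all samples as the test set\n",
    "X_tv, X_te, y_tv, y_te = train_test_split(\n",
    "    X_toy, y_toy, test_size=0.20, random_state=0, stratify=y_toy\n",
    ")\n",
    "# Then hold out 25% of the remaining 80% as the validation set\n",
    "X_tr, X_va, y_tr, y_va = train_test_split(\n",
    "    X_tv, y_tv, test_size=0.25, random_state=0, stratify=y_tv\n",
    ")\n",
    "\n",
    "# Split sizes and the fraction of positive labels in each split\n",
    "print(len(y_tr), len(y_va), len(y_te))\n",
    "print(y_tr.mean(), y_va.mean(), y_te.mean())\n",
    "```\n",
    "\n",
    "Because both calls pass `stratify`, each split keeps the same share of positive labels as the full toy set."
   ]
  },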
  {
   "cell_type": "markdown",
   "id": "c7071b8c-8a07-439d-a262-95ee015f655e",
   "metadata": {},
   "source": [
    "## Step 4: Perform hyperparameter search on the validation set"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "id": "576e52c2-e835-468b-b099-44d591f88738",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Best C on validation: 0.1  |  AUROC_val = 0.9973\n"
     ]
    }
   ],
   "source": [
    "# ------------------------------------------------------------------\n",
    "# Step 4: Hyperparameter search on the validation set\n",
    "# ------------------------------------------------------------------\n",
    "param_grid = np.logspace(-4, 4, 9)      # 1e-4 … 1e4\n",
    "best_auc, best_C, best_model = -np.inf, None, None\n",
    "\n",
    "for C in param_grid:\n",
    "    model = Pipeline(\n",
    "        steps=[\n",
    "            (\"scale\", StandardScaler()),\n",
    "            (\"clf\", LogisticRegression(\n",
    "                C=C, solver=\"liblinear\", max_iter=2000, penalty=\"l2\"\n",
    "            ))\n",
    "        ]\n",
    "    )\n",
    "    model.fit(X_train, y_train)\n",
    "\n",
    "    val_probs = model.predict_proba(X_val)[:, 1]\n",
    "    auc = roc_auc_score(y_val, val_probs)\n",
    "\n",
    "    if auc > best_auc:\n",
    "        best_auc, best_C, best_model = auc, C, model\n",
    "\n",
    "print(f\"Best C on validation: {best_C}  |  AUROC_val = {best_auc:.4f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "aff70a53-03ac-40d1-b5f7-bffe719417fa",
   "metadata": {},
   "source": [
    "## Step 5: Final evaluation on the held-out test set"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "id": "b92ecf7f-2701-44de-9ec0-5d291d7d2840",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "AUROC_test = 0.9837\n"
     ]
    }
   ],
   "source": [
    "# ------------------------------------------------------------------\n",
    "# Step 5: Final evaluation on the held-out test set\n",
    "# ------------------------------------------------------------------\n",
    "test_probs = best_model.predict_proba(X_test)[:, 1]\n",
    "test_auc  = roc_auc_score(y_test, test_probs)\n",
    "\n",
    "print(f\"AUROC_test = {test_auc:.4f}\")"
   ]
  },
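  {
   "cell_type": "markdown",
   "id": "fit-and-score-sketch",
   "metadata": {},
   "source": [
    "To repeat Steps 3 to 5 for other traits, the logic can be wrapped in a single function. This is a sketch under stated assumptions: `fit_and_score` is a hypothetical helper name, not part of Bacformer, and synthetic data from `make_classification` stands in for the real embeddings so that the snippet runs without downloading anything.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "from sklearn.datasets import make_classification\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "from sklearn.metrics import roc_auc_score\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.pipeline import Pipeline\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "\n",
    "\n",
    "def fit_and_score(X, y, grid=None, seed=42):\n",
    "    # Hypothetical helper: 60/20/20 split, choose C by validation AUROC,\n",
    "    # then report the test AUROC of the chosen model.\n",
    "    grid = np.logspace(-4, 4, 9) if grid is None else grid\n",
    "    X_tv, X_test, y_tv, y_test = train_test_split(\n",
    "        X, y, test_size=0.20, random_state=seed, stratify=y\n",
    "    )\n",
    "    X_train, X_val, y_train, y_val = train_test_split(\n",
    "        X_tv, y_tv, test_size=0.25, random_state=seed, stratify=y_tv\n",
    "    )\n",
    "\n",
    "    best_auc, best_C, best_model = -np.inf, None, None\n",
    "    for C in grid:\n",
    "        model = Pipeline([\n",
    "            ('scale', StandardScaler()),\n",
    "            ('clf', LogisticRegression(C=C, solver='liblinear', max_iter=2000)),\n",
    "        ])\n",
    "        model.fit(X_train, y_train)\n",
    "        auc = roc_auc_score(y_val, model.predict_proba(X_val)[:, 1])\n",
    "        if auc > best_auc:\n",
    "            best_auc, best_C, best_model = auc, C, model\n",
    "\n",
    "    test_auc = roc_auc_score(y_test, best_model.predict_proba(X_test)[:, 1])\n",
    "    return best_C, best_auc, test_auc\n",
    "\n",
    "\n",
    "# Synthetic stand-in for the embeddings and labels (assumption for illustration)\n",
    "X_demo, y_demo = make_classification(n_samples=300, n_features=20, random_state=0)\n",
    "best_C, val_auc, test_auc = fit_and_score(X_demo, y_demo)\n",
    "print(f'best C = {best_C}, val AUROC = {val_auc:.3f}, test AUROC = {test_auc:.3f}')\n",
    "```\n",
    "\n",
    "With the real data, build `X` and `y` for each phenotype column as in Step 2 and pass them to the helper."
   ]
  },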
  {
   "cell_type": "markdown",
   "id": "bd14a5d1-d940-4ad5-bdd9-68686fd082b1",
   "metadata": {},
   "source": [
    "----------------------\n",
    "#### Voilà, you made it 👏!\n",
    "\n",
    "There are 139 phenotypic traits to choose from and experiment with!\n",
    "\n",
    "If you run into any issues or have questions, please open an issue on GitHub: https://github.com/macwiatrak/Bacformer/issues\n",
    "\n",
    "If your task is highly challenging and precomputed genome embeddings do not yield sufficient performance, or if it requires embedding genomes, we recommend following `finetune_phenotypic_traits_prediction_tutorial.ipynb`, which outlines how to finetune the entire\n",
    "Bacformer model for phenotypic traits prediction.\n",
    "\n",
    "## References\n",
    "\n",
    "[1] Weimann, Aaron, et al. \"From genomes to phenotypes: Traitar, the microbial trait analyzer.\""
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.17"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
